#!/usr/bin/env python

import roslib; roslib.load_manifest('bwi_apps')
import rospy
import thread

import tf
from tf import TransformListener

from geometry_msgs.msg import Point
from geometry_msgs.msg import PointStamped
from bwi_msgs.msg import MultiLevelMapData
from bwi_msgs.msg import PersonDetectionArray
from bwi_msgs.msg import PersonDetection

import bwi_utils.utils
utils = bwi_utils.utils

class LocationAggregator:
  """Merge PersonDetection messages from several nodes into one array.

  Every node listed in the private parameter ``~nodes`` publishes on
  ``<node>/person_detections``. The latest detection per person id is kept in
  a PersonDetectionArray that is published on ``~out_topic`` at
  ``~publish_rate`` Hz and then cleared. In simulated mode (``~simulated``)
  each detection is additionally assigned the map level whose bounds contain
  its ``feet`` point, and ``feet`` is shifted into that level's frame.
  """

  def __init__(self):
    """Initialise the ROS node, wait for map metadata, then publish forever.

    Returns early (without spinning) if ``~nodes`` is missing or ROS shuts
    down before setup completes; otherwise blocks inside spin().
    """
    rospy.init_node('listener', anonymous=True)

    try:
      self.nodes = rospy.get_param('~nodes')
    except KeyError:
      rospy.logfatal("Did you forget to supply the node list (~nodes)")
      return
    rospy.loginfo("Aggregating person detections from nodes: %s" % (self.nodes,))

    self.out_msg = PersonDetectionArray()
    self.publish_rate = rospy.get_param('~publish_rate', 10)
    self.simulated = rospy.get_param("~simulated", True)
    self.out_topic = rospy.get_param('~out_topic', '/global/person_detections')
    self.global_frame = rospy.get_param("~global_frame", "/map")
    self.publisher = rospy.Publisher(self.out_topic, PersonDetectionArray)

    self.tf = TransformListener()

    # Level limits are only needed in simulated mode; until they have been
    # computed, simulated detections are dropped by callback().
    self.map_data_available = False
    self.map = None
    rospy.Subscriber("/map_metadata", MultiLevelMapData, self.mapCallback)

    rate = rospy.Rate(10.0)
    count = 0
    while not rospy.is_shutdown() and self.map is None:
      if count == 10:
        rospy.loginfo("Waiting for map data")
        count = 0
      rate.sleep()
      count = count + 1
    if rospy.is_shutdown():
      # The loop ended because of shutdown, not because a map arrived;
      # self.map may still be None.
      return
    rospy.loginfo("Map received")

    if self.simulated:
      self.obtainTransformsFromMapData()
      if rospy.is_shutdown():
        # Transform lookup was interrupted, so some level frames have no
        # cached transform and computeMapLimits() would raise KeyError.
        return
      self.computeMapLimits()
      rospy.loginfo("Map data calculated")

    self.mutex = thread.allocate_lock()

    for node in self.nodes:
      topic = "%s/person_detections" % node
      rospy.loginfo("Subscribing to %s for node %s" %(topic, node))
      rospy.Subscriber(topic, PersonDetection, self.callback)

    self.spin()

  def callback(self, person):
    """Record one PersonDetection in the outgoing array.

    In simulated mode the detection's ``level_id`` is overwritten with the
    level whose limits contain ``person.feet`` ("" if none), and, when a
    level is found, ``person.feet`` is shifted in place into that level's
    frame. An earlier detection with the same ``id`` is replaced.
    """
    # Only simulated mode needs the map limits to locate a person's level;
    # outside simulation detections are aggregated unconditionally.
    if self.simulated and not self.map_data_available:
      return

    with self.mutex:
      if self.simulated:
        person.level_id = self.computePersonLevel(person.feet)
        if person.level_id:
          frame_id = utils.frameIdFromLevelId(person.level_id)
          self.transformPoint(frame_id, person.feet, person.level_id)
      self.addPersonSafely(self.out_msg.detections, person)

  def mapCallback(self, data):
    """Store the most recent MultiLevelMapData message."""
    self.map = data

  def obtainTransformsFromMapData(self):
    """Cache the translation between every map level frame and the global frame.

    For each level, retries tf.lookupTransform at 10 Hz until it succeeds or
    ROS shuts down. Only the translation part of each transform is kept, in
    ``self.transforms`` keyed by the level's frame id.
    """
    self.transforms = dict()
    for level in self.map.levels:
      frame_id = utils.frameIdFromLevelId(level.level_id)
      rospy.loginfo("Attempting to get transform information %s -> %s" %(frame_id, self.global_frame))
      rate = rospy.Rate(10)
      count = 0
      while not rospy.is_shutdown():
        try:
          (trans, _rot) = self.tf.lookupTransform(self.global_frame, frame_id, rospy.Time(0))
        except tf.Exception:
          rate.sleep()
          if count == 10:
            rospy.logwarn("Unable to get the transformation from %s. Trying again..." % frame_id)
            count = 0
          count = count + 1
          continue
        self.transforms[frame_id] = trans
        break

  def _cornerInGlobalFrame(self, frame_id, x, y):
    """Return a PointStamped at (x, y, 0) in frame_id, shifted into the global frame."""
    corner = PointStamped()
    corner.header.stamp = rospy.Time.now()
    corner.header.frame_id = frame_id
    corner.point.x = x
    corner.point.y = y
    corner.point.z = 0
    self.transformPoint(self.global_frame, corner)
    return corner

  def computeMapLimits(self):
    """Compute each level's bounding corners in the global frame.

    Fills ``self.map_corners[level_id]`` with ``[origin_corner,
    opposite_corner]`` (both PointStamped in the global frame), derived from
    the level's map origin, width/height and resolution. Sets
    ``map_data_available`` once every level has been processed.
    """
    self.map_corners = dict()

    for level in self.map.levels:
      level_id = level.level_id
      rospy.loginfo("getting frame for level '%s'" % level_id)
      frame_id = utils.frameIdFromLevelId(level_id)

      map_origin = level.info.origin
      map_width = level.info.width * level.info.resolution
      map_height = level.info.height * level.info.resolution

      origin_corner = self._cornerInGlobalFrame(
          frame_id, map_origin.position.x, map_origin.position.y)
      opposite_corner = self._cornerInGlobalFrame(
          frame_id, map_origin.position.x + map_width, map_origin.position.y + map_height)

      self.map_corners[level_id] = [origin_corner, opposite_corner]
      rospy.loginfo("computed map limits for level %s" % level_id)

    # Set only after all levels are done so callers never see partial limits.
    self.map_data_available = True

  def transformPoint(self, target, point, source = None):
    """Shift a point in place between a level frame and the global frame.

    target: frame to express the point in.
    point: a Point, or a PointStamped whose header.frame_id supplies the
      source frame and is rewritten to ``target``.
    source: source frame for a bare Point; only consulted when ``target`` is
      the global frame (otherwise the transform is looked up by ``target``).

    Only the cached translation is added (global-bound) or subtracted
    (level-bound); rotation is not stored by obtainTransformsFromMapData.
    NOTE(review): this presumably assumes level frames are not rotated
    relative to the global frame — confirm.
    Raises KeyError if no transform is cached for the level frame involved.
    """
    if isinstance(point, PointStamped):
      source = point.header.frame_id
      point.header.frame_id = target
      point = point.point
    if target == self.global_frame:
      translation, sign = self.transforms[source], 1
    else:
      translation, sign = self.transforms[target], -1
    point.x = point.x + sign * translation[0]
    point.y = point.y + sign * translation[1]
    point.z = point.z + sign * translation[2]

  def computePersonLevel(self, data, tolerance=0.1):
    """Return the id of the level whose limits contain ``data``, or "".

    data: point with x/y/z in the global frame.
    tolerance: margin added on every side of each level's box (strict
      inequalities). In z the box is centred on the origin corner's z.
    """
    for level_id, corners in self.map_corners.items():
      first, second = corners[0].point, corners[1].point
      in_x = min(first.x, second.x) - tolerance < data.x < max(first.x, second.x) + tolerance
      in_y = min(first.y, second.y) - tolerance < data.y < max(first.y, second.y) + tolerance
      in_z = first.z - tolerance < data.z < first.z + tolerance
      if in_x and in_y and in_z:
        return level_id

    return ""

  def addPersonSafely(self, list, object):
    """Append ``object`` to ``list``, first removing the last entry with the same ``id``."""
    for i in reversed(range(len(list))):
      if list[i].id == object.id:
        del list[i]
        break
    list.append(object)

  def spin(self):
    """Publish the aggregated detections at ``publish_rate`` Hz, clearing them after each publish."""
    rate = rospy.Rate(self.publish_rate)
    while not rospy.is_shutdown():
      with self.mutex:
        self.publisher.publish(self.out_msg)
        del self.out_msg.detections[:]
      rate.sleep()

if __name__ == '__main__':
  try:
    # The constructor blocks in spin() until ROS shuts down.
    location_aggregator = LocationAggregator()
  except rospy.ROSInterruptException:
    # rospy.Rate.sleep() raises this when the node is shut down mid-sleep;
    # treat it as a normal exit instead of printing a traceback.
    pass
